from __future__ import print_function, division

import copy
import functools

import torch
import torch.nn.parallel
import torch.utils.data
import torchvision.utils as vutils
import numpy as np

from tensorboardX.writer import SummaryWriter
from torch.optim.adam import Adam
from typing import Callable, Dict, List, Union


def make_iterative_func(func: Callable) -> Callable:
    """Lift *func* so it is applied recursively to every leaf of a
    nested structure of lists, tuples, and dicts.

    Containers are rebuilt with the same type; any non-container value
    is treated as a leaf and passed to *func*.
    """

    def apply_recursive(obj):
        if isinstance(obj, dict):
            return {key: apply_recursive(val) for key, val in obj.items()}
        if isinstance(obj, (list, tuple)):
            mapped = [apply_recursive(item) for item in obj]
            # Preserve the container type (list stays list, tuple stays tuple).
            return mapped if isinstance(obj, list) else tuple(mapped)
        return func(obj)

    return apply_recursive


def make_nograd_func(func: Callable) -> Callable:
    """Wrap *func* so it always executes with autograd disabled.

    Useful for validation/inference code where gradient tracking would
    only waste memory and time.
    """

    # functools.wraps preserves the wrapped function's __name__/__doc__,
    # which the original version lost (every wrapped func reported
    # itself as "wrapper").
    @functools.wraps(func)
    def wrapper(*f_args, **f_kwargs):
        with torch.no_grad():
            return func(*f_args, **f_kwargs)

    return wrapper


@make_iterative_func
def tensor2float(vars):
    """Convert every leaf of a nested structure to a Python float.

    Floats pass through unchanged; zero-dim tensors are converted via
    .item(); anything else raises NotImplementedError.
    """
    if isinstance(vars, float):
        return vars
    elif isinstance(vars, torch.Tensor):
        # .item() already detaches and copies to host; the legacy .data
        # attribute the original used is redundant and discouraged.
        return vars.item()
    else:
        raise NotImplementedError("invalid input type for tensor2float")


@make_iterative_func
def tensor2numpy(vars):
    """Convert every leaf of a nested structure to a numpy array.

    ndarrays pass through unchanged; tensors are detached, moved to the
    CPU, and converted; anything else raises NotImplementedError.
    """
    if isinstance(vars, np.ndarray):
        return vars
    elif isinstance(vars, torch.Tensor):
        # .detach() is the documented replacement for the legacy .data
        # attribute; behavior is identical here.
        return vars.detach().cpu().numpy()
    else:
        raise NotImplementedError("invalid input type for tensor2numpy")


@make_iterative_func
def check_allfloat(vars):
    """Assert that every leaf of a nested structure is a Python float.

    Raises AssertionError (with the offending type) otherwise.
    """
    # Raise explicitly instead of a bare `assert`: the original check
    # disappeared entirely under `python -O`. The exception type is
    # unchanged, so existing callers are unaffected.
    if not isinstance(vars, float):
        raise AssertionError("expected float, got {}".format(type(vars)))


def save_scalars(
    logger: SummaryWriter,
    mode_tag: str,
    scalar_dict: Dict[str, Union[float, List[float]]],
    global_step: int,
) -> None:
    """Log every scalar in *scalar_dict* to tensorboard under *mode_tag*.

    Each value may be a float/tensor or a list/tuple of them; every
    entry is logged as "{mode_tag}/{tag}_{idx}" (the index suffix is
    added even for single values).
    """
    scalars = tensor2float(scalar_dict)
    for tag, values in scalars.items():
        # Normalize singletons to a one-element list.
        if not isinstance(values, (list, tuple)):
            values = [values]
        for idx, value in enumerate(values):
            name = "{}/{}".format(mode_tag, tag) + "_" + str(idx)
            logger.add_scalar(name, value, global_step)


def save_images(logger, mode_tag, images_dict, global_step):
    """Log the first sample of each image array in *images_dict* to tensorboard.

    Values may be tensors/ndarrays or lists of them. 3-D inputs get a
    channel axis inserted (assumes (N, H, W) layout — TODO confirm with
    callers); only the first item of the batch is logged.
    """
    arrays = tensor2numpy(images_dict)
    for tag, values in arrays.items():
        # Normalize singletons to a one-element list.
        if not isinstance(values, (list, tuple)):
            values = [values]
        multiple = len(values) > 1
        for idx, value in enumerate(values):
            if len(value.shape) == 3:
                # Insert a channel axis: (N, H, W) -> (N, 1, H, W).
                value = value[:, np.newaxis, :, :]
            # Keep only the first sample of the batch.
            sample = torch.from_numpy(value[:1])

            image_name = "{}/{}".format(mode_tag, tag)
            if multiple:
                # Index suffix only when there are several images per tag.
                image_name = image_name + "_" + str(idx)
            grid = vutils.make_grid(
                sample, padding=0, nrow=1, normalize=True, scale_each=True
            )
            logger.add_image(image_name, grid, global_step)


def adjust_learning_rate(
    optimizer: Adam, epoch: int, base_lr: float, lrepochs: str
) -> None:
    """Apply a step-decay learning-rate schedule in place.

    *lrepochs* has the form "e1,e2,...:rate": starting from *base_lr*,
    the LR is divided by *rate* once for every listed epoch boundary
    that *epoch* has already reached (boundaries are assumed sorted —
    scanning stops at the first one not yet reached). The result is
    written into every param group of *optimizer*.
    """
    parts = lrepochs.split(":")
    assert len(parts) == 2

    boundaries = [int(token) for token in parts[0].split(",")]
    rate = float(parts[1])
    print(
        "downscale epochs: {}, downscale rate: {}".format(
            boundaries, rate
        )
    )

    lr = base_lr
    for boundary in boundaries:
        if epoch < boundary:
            break
        lr /= rate
    print("setting learning rate to {}".format(lr))
    for group in optimizer.param_groups:
        group["lr"] = lr


class AverageMeter(object):
    """Tracks the running mean of a stream of float values."""

    def __init__(self):
        # Running total and number of samples seen so far.
        self.sum_value = 0.0
        self.count = 0

    def update(self, x):
        # x must contain only floats (checked recursively).
        check_allfloat(x)
        self.count += 1
        self.sum_value += x

    def mean(self):
        # NOTE: raises ZeroDivisionError if called before any update().
        return self.sum_value / self.count


class AverageMeterDict(object):
    """Accumulates per-key running sums of (nested) float metrics.

    Each update() adds a measurement dict element-wise into the running
    totals; mean() returns the per-key averages over all updates.
    """

    def __init__(self) -> None:
        self.data = None  # running sums; deep copy of the first update
        self.count = 0    # number of update() calls so far

    def update(self, x: Dict[str, Union[float, List[float]]]) -> None:
        """Add one measurement dict; every leaf must be a Python float."""
        check_allfloat(x)
        self.count += 1
        if self.data is None:
            # Deep-copy the first sample so later in-place additions do
            # not mutate the caller's dict.
            self.data = copy.deepcopy(x)
        else:
            for k1, v1 in x.items():
                if isinstance(v1, float):
                    self.data[k1] += v1
                elif isinstance(v1, (tuple, list)):
                    for idx, v2 in enumerate(v1):
                        self.data[k1][idx] += v2
                else:
                    # BUG FIX: the original did
                    # `assert NotImplementedError(...)` — an exception
                    # instance is always truthy, so the assert never
                    # fired and bad input was silently ignored. Raise.
                    raise NotImplementedError(
                        "error input type for update AvgMeterDict"
                    )

    def mean(self) -> Dict[str, Union[float, List[float]]]:
        """Return the element-wise mean of all updates so far."""
        @make_iterative_func
        def get_mean(v):
            return v / float(self.count)

        return get_mean(self.data)
